# Image repository to build on. No default is given, so it must be passed at
# build time for the FROM line below to resolve.
ARG IMAGE_NAME
# Common parent of every later stage: the "runtime" image whose tag is built
# from the CUDA major.minor version, the OS distro/version and an optional
# tag suffix.
FROM ${IMAGE_NAME}:{{ cuda.version.major }}.{{ cuda.version.minor }}-runtime-{{ cuda.os.distro }}{{ cuda.os.version }}{{ cuda.image_tag_suffix if "image_tag_suffix" in cuda }} AS base

{% if "x86_64" in cuda %}
# amd64 stage: adds nothing but ENV values that pin package names and
# versions; the final stage selects it through its base-<TARGETARCH> name.
FROM base AS base-amd64

ENV NV_CUDA_LIB_VERSION={{ cuda.x86_64.components.libraries.version }}
ENV NV_CUDA_CUDART_DEV_VERSION={{ cuda.x86_64.components.cudart_dev.version }}
ENV NV_NVML_DEV_VERSION={{ cuda.x86_64.components.nvml_dev.version }}
ENV NV_LIBCUSPARSE_DEV_VERSION={{ cuda.x86_64.components.libcusparse_dev.version }}
ENV NV_LIBNPP_DEV_VERSION={{ cuda.x86_64.components.libnpp_dev.version }}
{# cuBLAS dev package naming: 10.1 and 10.2 use the bare name, every other version handled by the else-branch below carries a -MAJOR-MINOR suffix. The suffix is always assigned so it can never render as an undefined value. #}
{% if cuda.version.major_minor in ["10.1", "10.2"] %}
{% set suffix = "" %}
{% else %}
{% set suffix = "-" ~ cuda.version.major ~ "-" ~ cuda.version.minor %}
{% endif %}
{% if cuda.version.major_minor in ["8.0", "9.0", "9.1", "9.2", "10.0"] %}
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME=cuda-cublas-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}
{% else %}
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME=libcublas-dev{{ suffix }}
{% endif %}

ENV NV_LIBCUBLAS_DEV_VERSION={{ cuda.x86_64.components.libcublas_dev.version }}
ENV NV_LIBCUBLAS_DEV_PACKAGE=${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}

{# NCCL values are only emitted when both the libnccl2 and libnccl2_dev components are present and non-empty. #}
{% if "libnccl2" in cuda.x86_64.components and cuda.x86_64.components.libnccl2 and "libnccl2_dev" in cuda.x86_64.components and cuda.x86_64.components.libnccl2_dev %}
ENV NV_LIBNCCL_DEV_PACKAGE_NAME=libnccl-dev
ENV NV_LIBNCCL_DEV_VERSION={{ cuda.x86_64.components.libnccl2_dev.version }}
ENV NCCL_VERSION=${NV_LIBNCCL_DEV_VERSION}
ENV NV_LIBNCCL_DEV_PACKAGE=${NV_LIBNCCL_DEV_PACKAGE_NAME}=${NV_LIBNCCL_DEV_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}

{% endif -%}
{# CUDA 8 lists every dev package individually; these values are consumed by the 8.0 branch of the install step in the final stage. #}
{% if cuda.version.major | int == 8 %}
ENV NV_CUDA_RUNTIME_DEV_VERSION={{ cuda.x86_64.components.cudart_dev.version }}
# The runtime dev package is cuda-cudart-dev; it previously repeated the
# cuda-core spec, which made it a duplicate of NV_CUDA_CORE_PACKAGE.
ENV NV_CUDA_RUNTIME_DEV_PACKAGE=cuda-cudart-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_RUNTIME_DEV_VERSION}
ENV NV_CUDA_CORE_VERSION={{ cuda.x86_64.components.cudart_dev.version }}
ENV NV_CUDA_CORE_PACKAGE=cuda-core-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_CORE_VERSION}
ENV NV_CUDA_DRIVER_DEV_VERSION={{ cuda.x86_64.components.cudart_dev.version }}
ENV NV_CUDA_DRIVER_DEV_PACKAGE=cuda-driver-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_DRIVER_DEV_VERSION}
ENV NV_CUDA_MISC_HEADERS_VERSION={{ cuda.x86_64.components.cudart.version }}
ENV NV_CUDA_MISC_HEADERS_PACKAGE=cuda-misc-headers-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_MISC_HEADERS_VERSION}
ENV NV_CUDA_NVRTC_DEV_VERSION={{ cuda.x86_64.components.nvrtc_dev.version }}
ENV NV_CUDA_NVRTC_DEV_PACKAGE=cuda-nvrtc-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_NVRTC_DEV_VERSION}
ENV NV_CUDA_NVGRAPH_DEV_VERSION={{ cuda.x86_64.components.nvgraph_dev.version }}
ENV NV_CUDA_NVGRAPH_DEV_PACKAGE=cuda-nvgraph-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_NVGRAPH_DEV_VERSION}
ENV NV_LIBCUSOLVER_DEV_VERSION={{ cuda.x86_64.components.libcusolver_dev.version }}
ENV NV_LIBCUSOLVER_DEV_PACKAGE=cuda-cusolver-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUSOLVER_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_VERSION={{ cuda.x86_64.components.libcublas_dev.version }}
ENV NV_LIBCUBLAS_DEV_PACKAGE=cuda-cublas-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_LIBCUFFT_DEV_VERSION={{ cuda.x86_64.components.libcufft_dev.version }}
ENV NV_LIBCUFFT_DEV_PACKAGE=cuda-cufft-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUFFT_DEV_VERSION}
ENV NV_LIBCURAND_DEV_VERSION={{ cuda.x86_64.components.libcurand_dev.version }}
ENV NV_LIBCURAND_DEV_PACKAGE=cuda-curand-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCURAND_DEV_VERSION}
{# ENV NV_LIBCUSPARSE_DEV_VERSION {{ cuda.x86_64.components.libcusparse_dev.version }} #}
ENV NV_LIBCUSPARSE_DEV_PACKAGE=cuda-cusparse-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUSPARSE_DEV_VERSION}
ENV NV_LIBNPP_DEV_VERSION={{ cuda.x86_64.components.libnpp_dev.version }}
ENV NV_LIBNPP_DEV_PACKAGE=cuda-npp-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBNPP_DEV_VERSION}
{% endif %}

{% endif %}
{% if "ppc64le" in cuda %}
# ppc64le stage: adds nothing but ENV values that pin package names and
# versions; the final stage selects it through its base-<TARGETARCH> name.
FROM base AS base-ppc64le

{# Every version in this stage is read from the ppc64le component table; NV_CUDA_LIB_VERSION previously read the x86_64 table. NOTE(review): assumes the ppc64le manifest has a "libraries" component like x86_64 - confirm. #}
ENV NV_CUDA_LIB_VERSION={{ cuda.ppc64le.components.libraries.version }}
ENV NV_CUDA_CUDART_DEV_VERSION={{ cuda.ppc64le.components.cudart_dev.version }}
ENV NV_NVML_DEV_VERSION={{ cuda.ppc64le.components.nvml_dev.version }}
ENV NV_LIBCUSPARSE_DEV_VERSION={{ cuda.ppc64le.components.libcusparse_dev.version }}
ENV NV_LIBNPP_DEV_VERSION={{ cuda.ppc64le.components.libnpp_dev.version }}
{# cuBLAS dev package naming: 10.1 and 10.2 use the bare name, every other version handled by the else-branch below carries a -MAJOR-MINOR suffix. The suffix is always assigned so it can never render as an undefined value. #}
{% if cuda.version.major_minor in ["10.1", "10.2"] %}
{% set suffix = "" %}
{% else %}
{% set suffix = "-" ~ cuda.version.major ~ "-" ~ cuda.version.minor %}
{% endif %}
{% if cuda.version.major_minor in ["8.0", "9.0", "9.1", "9.2", "10.0"] %}
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME=cuda-cublas-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}
{% else %}
ENV NV_LIBCUBLAS_DEV_PACKAGE_NAME=libcublas-dev{{ suffix }}
{% endif %}
ENV NV_LIBCUBLAS_DEV_VERSION={{ cuda.ppc64le.components.libcublas_dev.version }}
ENV NV_LIBCUBLAS_DEV_PACKAGE=${NV_LIBCUBLAS_DEV_PACKAGE_NAME}=${NV_LIBCUBLAS_DEV_VERSION}

{# NCCL values are only emitted when both the libnccl2 and libnccl2_dev components are present and non-empty. #}
{% if "libnccl2" in cuda.ppc64le.components and cuda.ppc64le.components.libnccl2 and "libnccl2_dev" in cuda.ppc64le.components and cuda.ppc64le.components.libnccl2_dev %}
ENV NV_LIBNCCL_DEV_PACKAGE_NAME=libnccl-dev
ENV NV_LIBNCCL_DEV_VERSION={{ cuda.ppc64le.components.libnccl2_dev.version }}
ENV NCCL_VERSION=${NV_LIBNCCL_DEV_VERSION}
ENV NV_LIBNCCL_DEV_PACKAGE=${NV_LIBNCCL_DEV_PACKAGE_NAME}=${NV_LIBNCCL_DEV_VERSION}+cuda{{ cuda.version.major }}.{{ cuda.version.minor }}

{% endif -%}

{# CUDA 8 lists every dev package individually; these values are consumed by the 8.0 branch of the install step in the final stage. #}
{% if cuda.version.major | int == 8 %}
ENV NV_CUDA_RUNTIME_DEV_VERSION={{ cuda.ppc64le.components.cudart_dev.version }}
# The runtime dev package is cuda-cudart-dev; it previously repeated the
# cuda-core spec, which made it a duplicate of NV_CUDA_CORE_PACKAGE.
ENV NV_CUDA_RUNTIME_DEV_PACKAGE=cuda-cudart-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_RUNTIME_DEV_VERSION}
ENV NV_CUDA_CORE_VERSION={{ cuda.ppc64le.components.cudart_dev.version }}
ENV NV_CUDA_CORE_PACKAGE=cuda-core-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_CORE_VERSION}
ENV NV_CUDA_DRIVER_DEV_VERSION={{ cuda.ppc64le.components.cudart_dev.version }}
ENV NV_CUDA_DRIVER_DEV_PACKAGE=cuda-driver-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_DRIVER_DEV_VERSION}
ENV NV_CUDA_MISC_HEADERS_VERSION={{ cuda.ppc64le.components.cudart.version }}
ENV NV_CUDA_MISC_HEADERS_PACKAGE=cuda-misc-headers-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_MISC_HEADERS_VERSION}
ENV NV_CUDA_NVRTC_DEV_VERSION={{ cuda.ppc64le.components.nvrtc_dev.version }}
ENV NV_CUDA_NVRTC_DEV_PACKAGE=cuda-nvrtc-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_NVRTC_DEV_VERSION}
ENV NV_CUDA_NVGRAPH_DEV_VERSION={{ cuda.ppc64le.components.nvgraph_dev.version }}
ENV NV_CUDA_NVGRAPH_DEV_PACKAGE=cuda-nvgraph-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_NVGRAPH_DEV_VERSION}
ENV NV_LIBCUSOLVER_DEV_VERSION={{ cuda.ppc64le.components.libcusolver_dev.version }}
ENV NV_LIBCUSOLVER_DEV_PACKAGE=cuda-cusolver-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUSOLVER_DEV_VERSION}
ENV NV_LIBCUBLAS_DEV_VERSION={{ cuda.ppc64le.components.libcublas_dev.version }}
ENV NV_LIBCUBLAS_DEV_PACKAGE=cuda-cublas-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUBLAS_DEV_VERSION}
ENV NV_LIBCUFFT_DEV_VERSION={{ cuda.ppc64le.components.libcufft_dev.version }}
ENV NV_LIBCUFFT_DEV_PACKAGE=cuda-cufft-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUFFT_DEV_VERSION}
ENV NV_LIBCURAND_DEV_VERSION={{ cuda.ppc64le.components.libcurand_dev.version }}
ENV NV_LIBCURAND_DEV_PACKAGE=cuda-curand-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCURAND_DEV_VERSION}
{# ENV NV_LIBCUSPARSE_DEV_VERSION {{ cuda.ppc64le.components.libcusparse_dev.version }} #}
ENV NV_LIBCUSPARSE_DEV_PACKAGE=cuda-cusparse-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBCUSPARSE_DEV_VERSION}
{# ENV NV_LIBNPP_DEV_VERSION {{ cuda.ppc64le.components.libnpp_dev.version }} #}
ENV NV_LIBNPP_DEV_PACKAGE=cuda-npp-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBNPP_DEV_VERSION}
{% endif %}
{% endif -%}

# Final image: builds on whichever base-<arch> stage matches the platform
# being built, so the ENV values pinned there drive the install below.
FROM base-${TARGETARCH}

# Re-declared inside the stage so the value is in scope for its instructions.
ARG TARGETARCH
LABEL maintainer="NVIDIA CORPORATION <cudatools@nvidia.com>"

# Install the development packages at the pinned versions and drop the apt
# lists in the same layer so they do not add to the image size.
RUN apt-get update && apt-get install -y --no-install-recommends \
    cuda-nvml-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_NVML_DEV_VERSION} \
    cuda-command-line-tools-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_LIB_VERSION} \
{% if cuda.version.major | int > 9 or (cuda.version.major | int == 9 and cuda.version.minor | int > 0) %}
    {# cuda-nvprof-{{ nvprof_component_version if nvprof_component_version and cuda.version.major | int >= 10 else "$CUDA_PKG_VERSION" }} \ #}
    cuda-nvprof-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_LIB_VERSION} \
{% endif %}
    {% if (cuda.arch != "arm64" and cuda.version.major | int <= 10) or cuda.version.build == 89 %}
    cuda-npp-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBNPP_DEV_VERSION} \
    {% else %}
    libnpp-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_LIBNPP_DEV_VERSION} \
    {% endif %}
{% if cuda.version.major | int >= 9 %}
    cuda-libraries-dev-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_LIB_VERSION} \
    cuda-minimal-build-{{ cuda.version.major }}-{{ cuda.version.minor }}=${NV_CUDA_LIB_VERSION} \
{% endif %}
{% if cuda.version.major_minor == "8.0" %}
    ${NV_CUDA_CORE_PACKAGE} \
    ${NV_CUDA_MISC_HEADERS_PACKAGE} \
    ${NV_CUDA_NVRTC_DEV_PACKAGE} \
    ${NV_CUDA_NVGRAPH_DEV_PACKAGE} \
    ${NV_LIBCUSOLVER_DEV_PACKAGE} \
    ${NV_LIBCUBLAS_DEV_PACKAGE} \
    ${NV_LIBCUFFT_DEV_PACKAGE} \
    ${NV_LIBCURAND_DEV_PACKAGE} \
    ${NV_LIBCUSPARSE_DEV_PACKAGE} \
    ${NV_LIBNPP_DEV_PACKAGE} \
    ${NV_CUDA_RUNTIME_DEV_PACKAGE} \
    ${NV_CUDA_DRIVER_DEV_PACKAGE} \
{% else %}
    ${NV_LIBCUBLAS_DEV_PACKAGE} \
    ${NV_LIBNCCL_DEV_PACKAGE} \
{% endif %}
    && rm -rf /var/lib/apt/lists/*

# Keep apt from auto upgrading the cublas package. See https://gitlab.com/nvidia/container-images/cuda/-/issues/88
RUN apt-mark hold ${NV_LIBCUBLAS_DEV_PACKAGE_NAME} ${NV_LIBNCCL_DEV_PACKAGE_NAME}

ENV LIBRARY_PATH=/usr/local/cuda/lib64/stubs
